// @HEADER
//*********************************************************************//
//  SiYuan: A numerical PDE solver                                     //
//  Copyright (2022) YUAN Xi                                           //
//  This Software is released under the BSD 2-Clause license detailed  //
//  in the file "LICENSE" in the top-level SiYuan directory            //
//*********************************************************************//
// @HEADER

#ifndef _ShallowWater_EquationSet_impl_hpp_
#define _ShallowWater_EquationSet_impl_hpp_

#include <stdexcept>
#include <string>

#include "Teuchos_ParameterList.hpp"
#include "Teuchos_StandardParameterEntryValidators.hpp"
#include "Teuchos_RCP.hpp"
#include "Teuchos_Assert.hpp"
#include "Phalanx_DataLayout_MDALayout.hpp"
#include "Phalanx_FieldManager.hpp"

#include "Panzer_IntegrationRule.hpp"
#include "Panzer_BasisIRLayout.hpp"

// ***********************************************************************
template <typename EvalT>
SiYuan::EquationSet_ShallowWater<EvalT>::
EquationSet_ShallowWater(const Teuchos::RCP<Teuchos::ParameterList>& params,
    const int& default_integration_order,
    const panzer::CellData& cell_data,
    const Teuchos::RCP<panzer::GlobalData>& global_data,
    const bool build_transient_support) :
    EquationSet<EvalT> (params,default_integration_order,cell_data,global_data,build_transient_support ) 
{
  // Registers the shallow-water DOFs, their time derivatives and the closure
  // model.
  //
  // Parameters read from `params` (in this order):
  //   "Model ID"                   - closure model to attach
  //   "Basis Order"                - order of the HGrad basis used for every DOF
  //   "Integration Order"          - quadrature order used for every DOF
  //   "Gravitational Acceleration" - stored in g_
  const std::string model_id = params->get<std::string>("Model ID");
  cellDim = cell_data.baseCellDimension();

  // Conserved variables: water depth h and the depth-integrated momenta
  // h*u_x and h*u_y.
  m_dof_names[0] = "h";
  m_dof_names[1] = "hu";
  m_dof_names[2] = "hv";

  const std::string basis_type = "HGrad";
  const int basis_order       = params->get<int>("Basis Order");
  const int integration_order = params->get<int>("Integration Order");
  g_ = params->get<double>("Gravitational Acceleration");

  // 1D problems carry (h, hu); higher-dimensional cells also carry hv.
  // Each DOF is added together with its time derivative, in name order.
  const int num_active_dofs = (cellDim > 1) ? 3 : 2;
  for (int dof = 0; dof < num_active_dofs; ++dof) {
    this->addDOF(m_dof_names[dof], basis_type, basis_order, integration_order);
    this->addDOFTimeDerivative(m_dof_names[dof]);
  }

  this->addClosureModel(model_id);
  this->setupDOFs();
}

// ***********************************************************************
template <typename EvalT>
void SiYuan::EquationSet_ShallowWater<EvalT>::
buildAndRegisterEquationSetEvaluators(PHX::FieldManager<panzer::Traits>& fm,
    const panzer::FieldLibrary& /* fl */,
    const Teuchos::ParameterList& ) const
{
  // Registers the single evaluator that assembles all shallow-water residuals.
  // The basis and integration rule are taken from the "h" DOF; the constructor
  // adds every DOF with the same basis type, basis order and integration order.
  const Teuchos::RCP<panzer::IntegrationRule> intRule =
      this->getIntRuleForDOF(m_dof_names[0]);
  const Teuchos::RCP<panzer::BasisIRLayout> basisLayout =
      this->getBasisIRLayoutForDOF(m_dof_names[0]);

  using ResidualEvaluator = SiYuan::Residual_ShallowWater<EvalT, panzer::Traits>;
  const Teuchos::RCP<PHX::Evaluator<panzer::Traits>> residualEvaluator =
      Teuchos::rcp(new ResidualEvaluator(panzer::EvaluatorStyle::EVALUATES,
                                         *basisLayout, *intRule, g_));
  this->template registerEvaluator<EvalT>(fm, residualEvaluator);
}

// ***********************************************************************

// ***********************************************************************
/////////////////////////////////////////////////////////////////////////////
//
//  Main Constructor: Residual_ShallowWater
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
SiYuan::Residual_ShallowWater<EvalT, Traits>::
Residual_ShallowWater(
    const panzer::EvaluatorStyle&   evalStyle,
    const panzer::BasisIRLayout&    basis,
    const panzer::IntegrationRule&  ir,
    const double& g )
    : evalStyle_(evalStyle), basisName_(basis.name()), g_(g)
{
    // Declares the fields this evaluator reads and writes.
    // - Reads the DOF values h, hu (and hv in 2D) and their time derivatives
    //   DXDT_* at integration points.
    // - Writes RESIDUAL_h, RESIDUAL_hu (and RESIDUAL_hv in 2D) at basis
    //   functions.
    //
    //   evalStyle : EVALUATES (residuals are zeroed, then filled) or
    //               CONTRIBUTES (terms are added to existing residuals)
    //   basis     : basis layout of the residual fields
    //   ir        : integration rule of the DOF values
    //   g         : gravitational acceleration, used in the flux g*h^2/2
    //
    // Throws std::invalid_argument if the basis is not 1D or 2D.
    cellDim = basis.dimension();

    // Only 1D (h, hu) and 2D (h, hu, hv) shallow water are implemented.
    // fields_ is sized cellDim+1 below, but only three residual names exist
    // and evaluateFields() uses at most two gradient components. Any other
    // dimension would register a default-constructed, nameless residual
    // field, so fail early with a clear message instead.
    TEUCHOS_TEST_FOR_EXCEPTION(cellDim < 1 || cellDim > 2, std::invalid_argument,
        "Residual_ShallowWater: only 1D and 2D cells are supported, but basis \""
        << basisName_ << "\" has dimension " << cellDim << ".");

    // Bed gradient at integration points, reserved for the bed-slope source
    // term. That term is not assembled yet, and the bed-depth dependency that
    // would feed it is disabled:
    //   bed_depth_ = PHX::MDField<const double, panzer::Cell, panzer::BASIS>("BedDepth", basis.functional);
    //   this->addDependentField(bed_depth_);
    bed_grad_ = PHX::MDField<double, panzer::Cell, panzer::IP, panzer::Dim >("BedGrad", ir.dl_vector);
    //  this->addEvaluatedField(bed_grad_);
    // Make this unshared so that it is not overwritten
    this->addUnsharedField(bed_grad_.fieldTag().clone());

    // Conserved variables at integration points.
    h_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("h", ir.dl_scalar);
    this->addDependentField(h_);
    hu_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("hu", ir.dl_scalar);
    this->addDependentField(hu_);
    if( cellDim>1 ) {
      hv_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("hv", ir.dl_scalar);
      this->addDependentField(hv_);
    }

    // Time derivatives of the conserved variables at integration points.
    dhdt_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("DXDT_h", ir.dl_scalar);
    this->addDependentField(dhdt_);
    dhudt_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("DXDT_hu", ir.dl_scalar);
    this->addDependentField(dhudt_);
    if( cellDim>1 ) {
      dhvdt_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>("DXDT_hv", ir.dl_scalar);
      this->addDependentField(dhvdt_);
    }

    // One residual per equation: continuity (h), x-momentum (hu) and, in 2D,
    // y-momentum (hv).
    fields_.resize(cellDim+1);
    fields_[0] = PHX::MDField<ScalarT, panzer::Cell, panzer::BASIS>("RESIDUAL_h",basis.functional);
    fields_[1] = PHX::MDField<ScalarT, panzer::Cell, panzer::BASIS>("RESIDUAL_hu",basis.functional);
    if( cellDim>1 )
      fields_[2] = PHX::MDField<ScalarT, panzer::Cell, panzer::BASIS>("RESIDUAL_hv",basis.functional);

    // Register the residuals as either contributed to or evaluated (owned).
    const bool contributes = (evalStyle == panzer::EvaluatorStyle::CONTRIBUTES);
    for( int i=0; i<cellDim+1; ++i) {
      if (contributes)
        this->addContributedField(fields_[i]);
      else // EvaluatorStyle::EVALUATES
        this->addEvaluatedField(fields_[i]);
    }

    // Set the name of this object.
    std::string n("ShallowWater (");
    n += contributes ? "CONTRIBUTES)" : "EVALUATES)";
    this->setName(n);
} // end of Main Constructor

/////////////////////////////////////////////////////////////////////////////
//
//  postRegistrationSetup()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Residual_ShallowWater<EvalT, Traits>:: postRegistrationSetup(
    typename Traits::SetupData sd, PHX::FieldManager<Traits>& /* fm */)
{
    // Resolves basisName_ to an index into the worksets' basis list.
    // evaluateFields() uses that index to fetch the basis values. The lookup
    // is done against the first workset from the setup data.
    auto& firstWorkset = (*sd.worksets_)[0];
    basisIndex_ = panzer::getBasisIndex(basisName_, firstWorkset, this->wda);
} // end of postRegistrationSetup()

/////////////////////////////////////////////////////////////////////////////
//
//  evaluateFields()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Residual_ShallowWater<EvalT, Traits>::
evaluateFields( typename Traits::EvalData workset )
{
    // Assembles the Galerkin weak form of the conservative shallow-water
    // equations. For each cell and basis function w:
    //
    //   R_h  = sum_qp [ dh/dt  * w - (hu * dw/dx + hv * dw/dy) ]
    //   R_hu = sum_qp [ dhu/dt * w - (g h^2/2 + hu^2/h) * dw/dx - (hu hv/h) * dw/dy ]
    //   R_hv = sum_qp [ dhv/dt * w - (g h^2/2 + hv^2/h) * dw/dy - (hu hv/h) * dw/dx ]
    //
    // In 1D only R_h and R_hu are assembled, without the y terms. No other
    // weight is applied in the sums, so the quadrature weights are expected
    // to be carried by Panzer's "weighted" basis arrays.
    //
    // With EVALUATES the residuals are zeroed first. With CONTRIBUTES the
    // terms are added to whatever the residual fields already hold.
    //
    // NOTE(review): the fluxes divide by h with no guard, so this assumes
    // h != 0 at every integration point (no wet/dry treatment).
    // TODO: the bed-slope source term
    //   + g*h * bed_grad_(cell, qp, dim) * w
    // is not assembled. bed_grad_ is declared in the constructor but never
    // filled.

    // Grab the basis information.
    const panzer::BasisValues2<double>& bv = *this->wda(workset).bases[basisIndex_];

    PHX::MDField<double, panzer::Cell, panzer::BASIS, panzer::IP, panzer::Dim>
          weightedGradBasis = bv.weighted_grad_basis;
    PHX::MDField<double, panzer::Cell, panzer::BASIS, panzer::IP>
          weightedBasis = bv.weighted_basis_scalar;

    const int numBases = weightedGradBasis.extent(1);
    const int numQP    = weightedGradBasis.extent(2);
    const bool evaluates = (evalStyle_ == panzer::EvaluatorStyle::EVALUATES);

    if( cellDim==1 ) {
      for( int cell=0; cell<workset.num_cells; ++cell )
      {
        for (int basis=0; basis < numBases; ++basis)
        {
          if (evaluates) {
            fields_[0](cell, basis) = 0.0;
            fields_[1](cell, basis) = 0.0;
          }
          for (int qp=0; qp < numQP; ++qp) {
            // Continuity: dh/dt + d(hu)/dx = 0
            fields_[0](cell, basis) += dhdt_(cell, qp) * weightedBasis(cell, basis, qp)
                                     - hu_(cell, qp) * weightedGradBasis(cell, basis, qp, 0);
            // x-momentum: d(hu)/dt + d(g h^2/2 + hu^2/h)/dx = 0
            fields_[1](cell, basis) += dhudt_(cell, qp) * weightedBasis(cell, basis, qp)
                - ( 0.5*g_*h_(cell,qp)*h_(cell,qp) + hu_(cell,qp)*hu_(cell,qp)/h_(cell,qp) )
                  * weightedGradBasis(cell, basis, qp, 0);
          }
        }
      }
    } else {
      for( int cell=0; cell<workset.num_cells; ++cell )
      {
        for (int basis=0; basis < numBases; ++basis)
        {
          if (evaluates) {
            fields_[0](cell, basis) = 0.0;
            fields_[1](cell, basis) = 0.0;
            fields_[2](cell, basis) = 0.0;
          }
          for (int qp=0; qp < numQP; ++qp) {
            // Continuity: dh/dt + d(hu)/dx + d(hv)/dy = 0
            fields_[0](cell, basis) += dhdt_(cell, qp) * weightedBasis(cell, basis, qp)
                                     - hu_(cell, qp) * weightedGradBasis(cell, basis, qp, 0)
                                     - hv_(cell, qp) * weightedGradBasis(cell, basis, qp, 1);
            // x-momentum: d(hu)/dt + d(g h^2/2 + hu^2/h)/dx + d(hu hv/h)/dy = 0
            fields_[1](cell, basis) += dhudt_(cell, qp) * weightedBasis(cell, basis, qp)
                - ( 0.5*g_*h_(cell,qp)*h_(cell,qp) + hu_(cell,qp)*hu_(cell,qp)/h_(cell,qp) )
                  * weightedGradBasis(cell, basis, qp, 0)
                - hu_(cell,qp)*hv_(cell,qp)/h_(cell,qp) * weightedGradBasis(cell, basis, qp, 1);
            // y-momentum: d(hv)/dt + d(hu hv/h)/dx + d(g h^2/2 + hv^2/h)/dy = 0
            fields_[2](cell, basis) += dhvdt_(cell, qp) * weightedBasis(cell, basis, qp)
                - ( 0.5*g_*h_(cell,qp)*h_(cell,qp) + hv_(cell,qp)*hv_(cell,qp)/h_(cell,qp) )
                  * weightedGradBasis(cell, basis, qp, 1)
                - hu_(cell,qp)*hv_(cell,qp)/h_(cell,qp) * weightedGradBasis(cell, basis, qp, 0);
          }
        }
      }
    }
} // end of evaluateFields()

#endif 
